import copy
import os
import sys
import parsee
import torch
import torch.nn as nn
import torch.nn.functional as F

from prepare_data import PrepareData
from model.attention import MultiHeadedAttention
from model.position_wise_feedforward import PositionwiseFeedForward
from model.embedding import PositionalEncoding, Embeddings
from model.transformer import Transformer
from model.encoder import Encoder, EncoderLayer
from model.decoder import Decoder, DecoderLayer
from model.generator import Generator
from lib.criterion import LabelSmoothing
from lib.optimizer import NoamOpt
from train import train
from evaluate import evaluate

def make_model(src_vocab, tgt_vocab, N = 6, d_model = 512, d_ff = 2048, h = 8, dropout = 0.1):
    """Build a Transformer, move it to ``parsee.device`` and Xavier-initialize it.

    Args:
        src_vocab: Size passed to the source-side ``Embeddings``.
        tgt_vocab: Size passed to the target-side ``Embeddings`` and to the
            ``Generator``.
        N: Forwarded to ``Encoder`` and ``Decoder`` (presumably the number of
            stacked layers -- confirm in model.encoder / model.decoder).
        d_model: Model width shared by every sub-module built here.
        d_ff: Forwarded to ``PositionwiseFeedForward`` together with
            ``d_model`` (presumably the hidden width -- confirm).
        h: Forwarded to ``MultiHeadedAttention`` (presumably the head count).
        dropout: Dropout value handed to the sub-modules that accept one.

    Returns:
        The assembled ``Transformer`` placed on ``parsee.device``.
    """
    device = parsee.device
    clone = copy.deepcopy

    # Prototype sub-modules. Every consumer below receives its own deep copy,
    # so the prototypes themselves never end up inside the model.
    attention = MultiHeadedAttention(h, d_model).to(device)
    feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout).to(device)
    pos_encoding = PositionalEncoding(d_model, dropout).to(device)

    encoder_layer = EncoderLayer(
        d_model, clone(attention), clone(feed_forward), dropout
    ).to(device)
    encoder = Encoder(encoder_layer, N).to(device)

    # The decoder layer takes two independent attention copies.
    decoder_layer = DecoderLayer(
        d_model, clone(attention), clone(attention), clone(feed_forward), dropout
    ).to(device)
    decoder = Decoder(decoder_layer, N).to(device)

    src_embed = nn.Sequential(
        Embeddings(d_model, src_vocab).to(device), clone(pos_encoding)
    )
    tgt_embed = nn.Sequential(
        Embeddings(d_model, tgt_vocab).to(device), clone(pos_encoding)
    )
    generator = Generator(d_model, tgt_vocab)

    model = Transformer(encoder, decoder, src_embed, tgt_embed, generator).to(device)

    # This was important in the reference code:
    # initialize parameters with Glorot / fan_avg. Only tensors with more than
    # one dimension are re-initialized; 1-D parameters keep the values their
    # constructors gave them.
    for param in model.parameters():
        if param.dim() <= 1:
            continue
        nn.init.xavier_uniform_(param)

    return model.to(device)

def main():
    """Prepare the data, build the model and run training.

    Stores the vocabulary sizes on the ``parsee`` module (``src_vocab`` /
    ``tgt_vocab``), builds the model from the hyper-parameters held there,
    then hands data, model, criterion and optimizer to ``train``.
    """
    # Data preprocessing
    data = PrepareData()
    parsee.src_vocab = len(data.en_word_dict)
    parsee.tgt_vocab = len(data.cn_word_dict)

    # Initialize the model
    model = make_model(
        parsee.src_vocab,
        parsee.tgt_vocab,
        parsee.layers,
        parsee.d_model,
        parsee.d_ff,
        parsee.h_num,
        parsee.dropout,
    )

    # Training
    print(">>>>>>> start train")
    # The original referenced `parsees.tgt_vocab`, an undefined name that
    # raised NameError before training could start.
    criterion = LabelSmoothing(parsee.tgt_vocab, padding_idx = 0, smoothing= 0.0)
    # Adam is created with lr=0; NoamOpt presumably sets the actual rate each
    # step from d_model and its other arguments -- confirm in lib.optimizer.
    optimizer = NoamOpt(
        parsee.d_model,
        1,
        2000,
        torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9),
    )

    train(data, model, criterion, optimizer)
    print("<<<<<<< finished train")


# Call main() only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()